{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Updates on the Chunking Algorithm\n",
    "\n",
    "<a target=\"_blank\" href=\"https://colab.research.google.com/github/sweepai/sweep/blob/main/notebooks/chunking.ipynb\">\n",
    "  <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
    "</a>\n",
    "\n",
    "This notebook is for the blog on improvements to our chunking algorithm. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install tree_sitter_languages\n",
    "\n",
    "import re\n",
    "from tree_sitter_languages import get_language, get_parser\n",
    "\n",
    "language = get_language(\"python\")\n",
    "parser = get_parser(\"python\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Meet the Span"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 140,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import annotations\n",
    "from dataclasses import dataclass\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class Span:\n",
    "    # Represents a slice of a string\n",
    "    start: int = 0\n",
    "    end: int = 0\n",
    "\n",
    "    def __post_init__(self):\n",
    "        # If end is None, set it to start\n",
    "        if self.end is None:\n",
    "            self.end = self.start\n",
    "\n",
    "    def extract(self, s: str) -> str:\n",
    "        # Grab the corresponding substring of string s by bytes\n",
    "        return s[self.start : self.end]\n",
    "\n",
    "    def extract_lines(self, s: str) -> str:\n",
    "        # Grab the corresponding substring of string s by lines\n",
    "        return \"\\n\".join(s.splitlines()[self.start : self.end])\n",
    "\n",
    "    def __add__(self, other: Span | int) -> Span:\n",
    "        # e.g. Span(1, 2) + Span(2, 4) = Span(1, 4) (concatenation)\n",
    "        # There are no safety checks: Span(a, b) + Span(c, d) = Span(a, d)\n",
    "        # and there are no requirements for b = c.\n",
    "        if isinstance(other, int):\n",
    "            return Span(self.start + other, self.end + other)\n",
    "        elif isinstance(other, Span):\n",
    "            return Span(self.start, other.end)\n",
    "        else:\n",
    "            raise NotImplementedError()\n",
    "\n",
    "    def __len__(self) -> int:\n",
    "        # i.e. Span(a, b) = b - a\n",
    "        return self.end - self.start"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The example code we're gonna use in this guide will be from https://github.com/sweepai/sweep/blob/b267b613d4c706eaf959fe6789f11e9a856521d1/sweepai/handlers/on_check_suite.py, our old handler for parsing GitHub Action run logs at Sweep."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 158,
   "metadata": {},
   "outputs": [],
   "source": [
    "import requests\n",
    "\n",
    "example_file = \"https://raw.githubusercontent.com/sweepai/sweep/b267b613d4c706eaf959fe6789f11e9a856521d1/sweepai/handlers/on_check_suite.py\"\n",
    "python_code = requests.get(example_file).text"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's first visualize the syntax tree."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tree = parser.parse(python_code.encode(\"utf-8\"))\n",
    "\n",
    "\n",
    "def pretty_node(node):\n",
    "    return f\"{node.type}:{node.start_byte}-{node.end_byte}\"\n",
    "\n",
    "\n",
    "def print_tree(node, indent=\"\"):\n",
    "    if len(re.sub(\"\\s\", \"\", node.text.decode(\"utf-8\"))) < 100:\n",
    "        return\n",
    "    print(indent + pretty_node(node))\n",
    "    for child in node.children:\n",
    "        print_tree(child, indent=indent + \"  \")\n",
    "\n",
    "\n",
    "for child in tree.root_node.children:\n",
    "    print_tree(child)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see that it doesn't actually line up:\n",
    "\n",
    "```python\n",
    "expression_statement:432-696\n",
    "  assignment:432-696\n",
    "    string:446-696\n",
    "function_definition:698-1502\n",
    "  block:777-1502\n",
    "    expression_statement:777-953\n",
    "      assignment:777-953\n",
    "        dictionary:787-953\n",
    "```\n",
    "\n",
    "Notice that the “expression_statement” ends on byte 696 and “function_definition” starts on byte 698, skipping a byte."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [],
   "source": [
    "def connect_chunks(chunks: list[Span]):\n",
    "    for prev, curr in zip(chunks[:-1], chunks[1:]):\n",
    "        prev.end = curr.start\n",
    "    return chunks"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Coalescing\n",
    "\n",
    "Here was the algo presented in the last blog. Unfortunately it has some bugs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tree_sitter import Node\n",
    "from dataclasses import field\n",
    "\n",
    "\n",
    "def chunk_node(node: Node, text: str, MAX_CHARS: int = 600) -> list[str]:\n",
    "    chunks = []\n",
    "    current_chunk = \"\"\n",
    "    for child in node.children:\n",
    "        if child.end_byte - child.start_byte > MAX_CHARS:\n",
    "            chunks.append(current_chunk)\n",
    "            current_chunk = \"\"\n",
    "            chunks.extend(chunk_node(child, text, MAX_CHARS))\n",
    "        elif child.end_byte - child.start_byte + len(current_chunk) > MAX_CHARS:\n",
    "            chunks.append(current_chunk)\n",
    "            current_chunk = text[child.start_byte : child.end_byte]\n",
    "        else:\n",
    "            current_chunk += text[child.start_byte : child.end_byte]\n",
    "    chunks.append(current_chunk)\n",
    "\n",
    "    return chunks\n",
    "\n",
    "\n",
    "for chunk in chunk_node(tree.root_node, python_code):\n",
    "    print(chunk + \"\\n\\n====================\\n\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here it is with the fixes by using the start_byte of the next node instead of the end_byte of the current node. \n",
    "\n",
    "I added a fake node at the end with start and end bytes equal to the end byte of the entire node. This is so that we don't need to rewrite the loop logic one last time for the last node. The purpose of MockNode is because the tree_sitter Node library doesn't have a constructor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclass\n",
    "class MockNode:\n",
    "    start_byte: int = 0\n",
    "    end_byte: int = 0\n",
    "    children: list[MockNode] = field(default_factory=list)\n",
    "\n",
    "\n",
    "def chunk_node(node: Node, text: str, MAX_CHARS: int = 600) -> list[str]:\n",
    "    chunks = []\n",
    "    current_chunk = \"\"\n",
    "    node_children = node.children + [MockNode(node.end_byte, node.end_byte)]\n",
    "\n",
    "    for child, next_child in zip(node_children[:-1], node_children[1:]):\n",
    "        if child.end_byte - child.start_byte > MAX_CHARS:\n",
    "            chunks.append(current_chunk)\n",
    "            current_chunk = \"\"\n",
    "            chunks.extend(chunk_node(child, text, MAX_CHARS))\n",
    "        elif child.end_byte - child.start_byte + len(current_chunk) > MAX_CHARS:\n",
    "            chunks.append(current_chunk)\n",
    "            current_chunk = text[child.start_byte : next_child.start_byte]\n",
    "        else:\n",
    "            current_chunk += text[child.start_byte : next_child.start_byte]\n",
    "    chunks.append(current_chunk)\n",
    "\n",
    "    return chunks\n",
    "\n",
    "\n",
    "for chunk in chunk_node(tree.root_node, python_code):\n",
    "    print(chunk + \"\\n\\n====================\\n\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Firstly, using Span's we can clean up the code a bit. Like removing the MockNode altogether."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def chunk_node(\n",
    "    node: Node,\n",
    "    MAX_CHARS: int = 600,\n",
    ") -> list[Span]:\n",
    "    chunks: list[Span] = []\n",
    "    current_chunk: Span = Span(node.start_byte, node.start_byte)\n",
    "    node_children = node.children\n",
    "    for child in node_children:\n",
    "        if child.end_byte - child.start_byte > MAX_CHARS:\n",
    "            chunks.append(current_chunk)\n",
    "            current_chunk = Span(child.end_byte, child.end_byte)\n",
    "            chunks.extend(chunk_node(child, MAX_CHARS))\n",
    "        elif child.end_byte - child.start_byte + len(current_chunk) > MAX_CHARS:\n",
    "            chunks.append(current_chunk)\n",
    "            current_chunk = Span(child.start_byte, child.end_byte)\n",
    "        else:\n",
    "            current_chunk += Span(child.start_byte, child.end_byte)\n",
    "    chunks.append(current_chunk)\n",
    "    return chunks\n",
    "\n",
    "\n",
    "for chunk in chunk_node(tree.root_node):\n",
    "    print(chunk)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Skipping Whitespace when Measuring Length\n",
    "\n",
    "Gives heavily indented code the same number of lines per code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 152,
   "metadata": {},
   "outputs": [],
   "source": [
    "def char_len(s: str) -> int:  # old len function\n",
    "    return len(s)\n",
    "\n",
    "\n",
    "def non_whitespace_len(s: str) -> int:  # new len function\n",
    "    return len(re.sub(\"\\s\", \"\", s))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Coalescing Chunks\n",
    "\n",
    "Combining smaller chunks with larger ones."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def coalesce_chunks(\n",
    "    chunks: list[Span], source_code: str, coalesce: int = 50\n",
    ") -> list[Span]:\n",
    "    new_chunks = []\n",
    "    current_chunk = Span(0, 0)\n",
    "    for chunk in chunks:\n",
    "        current_chunk += chunk\n",
    "        if len(current_chunk) > coalesce and \"\\n\" in current_chunk.extract(source_code):\n",
    "            new_chunks.append(current_chunk)\n",
    "            current_chunk = Span(chunk.end, chunk.end)\n",
    "    if len(current_chunk) > 0:\n",
    "        new_chunks.append(current_chunk)\n",
    "    return new_chunks\n",
    "\n",
    "\n",
    "for chunk in coalesce_chunks(chunk_node(tree.root_node), python_code):\n",
    "    print(chunk.extract(python_code))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Use Line Numbers\n",
    "\n",
    "Using line numbers instead of character indices. Works because Span is unit-agnostic."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_line_number(index: int, source_code: str) -> int:\n",
    "    total_chars = 0\n",
    "    for line_number, line in enumerate(source_code.splitlines(keepends=True), start=1):\n",
    "        total_chars += len(line)\n",
    "        if total_chars > index:\n",
    "            return line_number - 1\n",
    "    return line_number\n",
    "\n",
    "\n",
    "for i, chunk in enumerate(coalesce_chunks(chunk_node(tree.root_node), python_code)):\n",
    "    print(\n",
    "        f\"Chunk {i}: {get_line_number(chunk.start, python_code)}-{get_line_number(chunk.end, python_code)}\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Final New Algorithm\n",
    "\n",
    "Putting it altogether (switched back to MAX_CHARS=1500):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tree_sitter import Tree\n",
    "\n",
    "\n",
    "def chunker(\n",
    "    tree: Tree,\n",
    "    source_code: bytes,\n",
    "    MAX_CHARS=512 * 3,\n",
    "    coalesce=50,  # Any chunk less than 50 characters long gets coalesced with the next chunk\n",
    ") -> list[Span]:\n",
    "\n",
    "    # 1. Recursively form chunks based on the last post (https://docs.sweep.dev/blogs/chunking-2m-files)\n",
    "    def chunk_node(node: Node) -> list[Span]:\n",
    "        chunks: list[Span] = []\n",
    "        current_chunk: Span = Span(node.start_byte, node.start_byte)\n",
    "        node_children = node.children\n",
    "        for child in node_children:\n",
    "            if child.end_byte - child.start_byte > MAX_CHARS:\n",
    "                chunks.append(current_chunk)\n",
    "                current_chunk = Span(child.end_byte, child.end_byte)\n",
    "                chunks.extend(chunk_node(child))\n",
    "            elif child.end_byte - child.start_byte + len(current_chunk) > MAX_CHARS:\n",
    "                chunks.append(current_chunk)\n",
    "                current_chunk = Span(child.start_byte, child.end_byte)\n",
    "            else:\n",
    "                current_chunk += Span(child.start_byte, child.end_byte)\n",
    "        chunks.append(current_chunk)\n",
    "        return chunks\n",
    "\n",
    "    chunks = chunk_node(tree.root_node)\n",
    "\n",
    "    # 2. Filling in the gaps\n",
    "    for prev, curr in zip(chunks[:-1], chunks[1:]):\n",
    "        prev.end = curr.start\n",
    "    curr.start = tree.root_node.end_byte\n",
    "\n",
    "    # 3. Combining small chunks with bigger ones\n",
    "    new_chunks = []\n",
    "    current_chunk = Span(0, 0)\n",
    "    for chunk in chunks:\n",
    "        current_chunk += chunk\n",
    "        if non_whitespace_len(\n",
    "            current_chunk.extract(source_code)\n",
    "        ) > coalesce and \"\\n\" in current_chunk.extract(source_code):\n",
    "            new_chunks.append(current_chunk)\n",
    "            current_chunk = Span(chunk.end, chunk.end)\n",
    "    if len(current_chunk) > 0:\n",
    "        new_chunks.append(current_chunk)\n",
    "\n",
    "    # 4. Changing line numbers\n",
    "    line_chunks = [\n",
    "        Span(\n",
    "            get_line_number(chunk.start, source_code),\n",
    "            get_line_number(chunk.end, source_code),\n",
    "        )\n",
    "        for chunk in new_chunks\n",
    "    ]\n",
    "\n",
    "    # 5. Eliminating empty chunks\n",
    "    line_chunks = [chunk for chunk in line_chunks if len(chunk) > 0]\n",
    "\n",
    "    return line_chunks\n",
    "\n",
    "\n",
    "for chunk in chunker(tree, python_code):\n",
    "    print(chunk.extract_lines(python_code) + \"\\n\\n====================\\n\\n\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.11.3 ('sweepai-u_CIt3kb-py3.11')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "f22e56b0c638c2a35876232f2e2d6cfc31a0d98b7f3049980f1a4383610dba30"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
